# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""MIDAS (Multiple Imputation with Denoising Autoencoders).

The MIDAS implementation can be found at:
  https://github.com/ranjitlall/MIDAS

This currently only works with TensorFlow 1.x. To run it, follow these steps:
  1. Create a virtual environment and activate it.
  2. Install the latest legacy TensorFlow 1.x release:
    > pip3 install tensorflow==1.15
  3. Install MIDAS from GitHub:
    > pip3 install git+https://github.com/ranjitlall/MIDAS.git

References:
-----------
  [1] Lall, Ranjit and Robinson, Thomas (2020): "Applying the MIDAS Touch: How
      to Handle Missing Values in Large and Complex Data".
  [2] Lall, Ranjit and Robinson, Thomas (2020): "Applying the MIDAS Touch: An
      Accurate and Scalable Approach to Imputing Missing Data".

Example:
--------
(1) Training and imputation:
  The following trains a MIDAS model using TensorFlow on the data from
  ${DATASET_DIR}, saves the model to ${MODEL_DIR} and writes the datasets
  with imputed features to ${IMPUTED_DATASETS_DIR}. The datasets in
  ${DATASET_DIR} must be generated by passing the "--categorical_as_ints"
  option to the "sigtyp_reader" converter.

  > python3 midas_imputer_main.py \
      --input_dir ${DATASET_DIR} \
      --num_epochs 10 \
      --model_dir ${MODEL_DIR} \
      --output_datasets_dir ${IMPUTED_DATASETS_DIR}

(2) Imputation from an existing model:

  > MODEL_DIR=models/midas/train/1/midas
  > python3 midas_imputer_main.py \
      --input_dir ${DATASET_DIR} \
      --notrain \
      --model_dir ${MODEL_DIR} \
      --output_datasets_dir ${IMPUTED_DATASETS_DIR}

TODO:
-----
  - Figure out how to evaluate on data not seen during the training.
  - Tune hyper-parameters on the dev set.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os

from absl import app
from absl import flags
from absl import logging

import constants as const
import data_info as data_lib
from midas import Midas
import numpy as np
import pandas as pd

flags.DEFINE_string(
    "input_dir", "",
    "Input directory for preprocessed SIGTYP files.")

flags.DEFINE_string(
    "training_set_name", const.TRAIN_FILENAME,
    "Name of the training set. Can be \"train\" or \"train_dev\".")

flags.DEFINE_integer(
    "num_epochs", 10,
    "Number of epochs for training.")

flags.DEFINE_integer(
    "batch_size", 8,
    "Number of examples per batch.")

flags.DEFINE_list(
    "layer_structure", ["16", "16"],
    "Comma-separated list with the number of nodes in each layer.")

flags.DEFINE_string(
    "model_dir", "/tmp/MIDAS",
    "Directory where to store or load the model.")

flags.DEFINE_string(
    "output_datasets_dir", "",
    "Directory for storing datasets with imputed values.")

flags.DEFINE_integer(
    "num_datasets", 5,
    "Number of imputed datasets to generate.")

flags.DEFINE_boolean(
    "save_softmax_columns", False,
    "Saves the candidate datasets with all the softmax columns for categorical "
    "data. If disabled, the softmax columns are converted back to the normal "
    "format.")

flags.DEFINE_boolean(
    "train", True,
    ("Train the model. If disabled, the model will be restored from the "
     "checkpoints located in directory specified by --model_dir flag."))

# Only consulted when --save_softmax_columns is disabled (see `main`).
flags.DEFINE_boolean(
    "categorical_as_ints", False,
    ("If enabled, the categorical columns of the imputed datasets are kept as "
     "integer codes. If disabled, they are converted back to strings. Ignored "
     "when --save_softmax_columns is enabled."))

FLAGS = flags.FLAGS

# Tolerance used when checking that the softmax probabilities predicted for a
# categorical feature sum to one.
_EPSILON = 1E-5


def _load_data_info():
  """Reads the data info mappings that accompany the training set.

  The mappings file lives in --input_dir and is named
  "<DATA_INFO_FILENAME>_<training_set_name><FILE_EXTENSION>".

  Returns:
    Whatever `data_lib.load_data_info` returns for that file.
  """
  basename = "".join([const.DATA_INFO_FILENAME, "_",
                      FLAGS.training_set_name, data_lib.FILE_EXTENSION])
  info_path = os.path.join(FLAGS.input_dir, basename)
  return data_lib.load_data_info(info_path)


def _encode_categorical(df):
  """One-hot encodes every column except the coordinates.

  Every column other than "latitude" and "longitude" is treated as
  categorical. A categorical column V is replaced by one indicator column per
  distinct non-missing value, named "V_<value>" by `pd.get_dummies` (e.g.
  "V_1.0" when the codes are stored as floats). Rows where V is missing get
  NaN in all of V's indicator columns, so the imputer sees them as missing.

  The input dataframe is not modified.

  Args:
    df: (dataframe) Pandas dataframe to encode.

  Returns:
    A tuple `(columns_list, encoded_df)`. `columns_list` holds, for each
    categorical column, the list of its indicator column names. `encoded_df`
    holds the non-categorical columns followed by all the indicator columns.
  """
  logging.info("Encoding categorical columns ...")
  categorical = [
      col for col in df.columns if col not in ["latitude", "longitude"]]

  # Non-inplace drop: the caller's dataframe must stay untouched.
  constructor_list = [df.drop(categorical, axis=1)]
  columns_list = []
  for column in categorical:
    values = df[column]
    missing = values.isnull()
    dummies = pd.get_dummies(values, prefix=column)
    if missing.any():
      # Indicator columns may be integer or boolean depending on the pandas
      # version; neither can hold NaN, so cast explicitly before masking.
      dummies = dummies.astype(float)
      dummies.loc[missing] = np.nan
    constructor_list.append(dummies)
    columns_list.append(list(dummies.columns.values))

  encoded_df = pd.concat(constructor_list, axis=1)
  logging.info("Encoded %d categorical columns.", len(columns_list))
  return columns_list, encoded_df


def _load_df(file_type):
  """Loads and one-hot encodes the SIGTYP dataset of the given type.

  Reads the pipe-separated "<file_type>.csv" from --input_dir, drops the
  identifier-like columns that are not used for imputation and one-hot
  encodes the remaining categorical columns with `_encode_categorical`.

  Args:
    file_type: (string) Base name of the CSV file (e.g. the training set
      name).

  Returns:
    A tuple `(original_df, categorical_columns, df)`:
      - the dataframe as read from disk;
      - for each categorical feature, the list of its indicator column names;
      - the encoded dataframe to feed to the imputer.
  """
  input_file = os.path.join(FLAGS.input_dir, file_type + ".csv")
  logging.info("Reading \"%s\" ...", input_file)
  original_df = pd.read_csv(input_file, sep="|", encoding=const.ENCODING)
  logging.info("Read %d entries", original_df.shape[0])
  # Report through the logger (not stdout) and log an actual summary rather
  # than the entire dataframe.
  logging.info("Summary:\n%s", original_df.describe(include="all"))
  logging.info("Types:\n%s", original_df.dtypes)

  # Drop some of the non-interesting features for now. `drop` returns a new
  # dataframe, so `original_df` stays intact without an explicit copy.
  exclude_columns = ["wals_code", "name", "countrycodes"]
  df = original_df.drop(exclude_columns, axis=1)
  categorical_columns, df = _encode_categorical(df)
  logging.info("Data shape: %s", df.shape)
  return original_df, categorical_columns, df


def _convert_from_softmax(original_df, df, data_info):
  """Fills missing categorical values from the imputer's softmax columns.

  For every feature column C of `original_df` (everything except the
  identifier, coordinate, genus and family columns), the imputed dataset `df`
  holds one probability column per feature value, named "C_<i>.0" for the
  1-based value code i. Values missing in `original_df` are replaced by the
  code of the most probable column; values already present are kept.

  Args:
    original_df: (dataframe) Pandas dataframe representing original dataset.
    df: (dataframe) Pandas dataframe representing the dataset with softmax
        columns. Assumed to share `original_df`'s index.
    data_info: (dict) Dictionary containing information about the features.

  Returns:
    A copy of `original_df` with the missing feature values filled in and the
    feature columns cast to integers. `original_df` is not modified.

  Raises:
    ValueError: If a feature has no values in `data_info`, or if its softmax
      probabilities do not sum to one for some row.
  """
  good_columns = ["wals_code", "name", "latitude", "longitude",
                  "genus", "family", "countrycodes"]
  predict_categorical = [
      col for col in original_df.columns if col not in good_columns]

  update_df = original_df.copy()  # Important to clone the original.
  num_updated_values = 0
  for col_name in predict_categorical:
    num_values = len(data_info[const.DATA_KEY_FEATURES][col_name]["values"])
    if num_values <= 0:
      raise ValueError("%s: Bad number of feature values!" % col_name)
    # For original column C, the softmax column $i$ is named "C_i.0" (where i
    # is treated as float).
    softmax_names = [
        col_name + "_" + str(i + 1) + ".0" for i in range(num_values)]
    # Gather a (rows, num_values) probability matrix. `get_dummies` only
    # creates indicator columns for values observed in the training data, so
    # (presuming the imputer keeps its input column names) a value without a
    # column gets zero probability. Naming mismatches are still caught by the
    # sum check below.
    probs = df.reindex(index=update_df.index, columns=softmax_names,
                       fill_value=0.0).values.astype(np.float64)
    sums = probs.sum(axis=1)
    # Negated comparison so that NaN sums are reported as well.
    bad_rows = ~(np.abs(sums - 1.0) < _EPSILON)
    if bad_rows.any():
      raise ValueError("%s: Probs don't sum to one (sum = %f)!" %
                       (col_name, sums[bad_rows][0]))

    # Replace the missing values with the value of the most probable column.
    missing = update_df[col_name].isnull().values
    if missing.any():
      update_df.loc[missing, col_name] = np.argmax(probs[missing], axis=1) + 1
      num_updated_values += int(missing.sum())
  logging.info("Number of values filled in: %d", num_updated_values)

  # Columns that contained missing values were read as floats. Coerce them
  # back to integers.
  for col_name in predict_categorical:
    update_df[col_name] = update_df[col_name].astype(int)
  return update_df


def _decode_codes(values, value_list, name):
  """Maps a series of 1-based integer codes to entries of `value_list`.

  Args:
    values: (series) Pandas series of codes; integral floats are accepted.
    value_list: (list) Values indexed by `code - 1`.
    name: (string) Column name used in error messages.

  Returns:
    A list with the decoded value for every entry of `values`, in order.

  Raises:
    ValueError: If a code is missing, non-integral or outside
      [1, len(value_list)].
  """
  raw = values.values
  codes = np.asarray(raw, dtype=np.float64)
  # NaN fails every comparison, so missing codes are flagged as invalid too.
  valid = ((codes >= 1) & (codes <= len(value_list)) &
           (codes == np.floor(codes)))
  if not valid.all():
    raise ValueError("%s: Invalid value: %s" % (name, raw[~valid][0]))
  return [value_list[code - 1] for code in codes.astype(int).tolist()]


def _convert_categorical(df, data_info):
  """Converts all the categorical columns back to strings.

  Genus and family codes are decoded with their own value tables from
  `data_info`. Every other column apart from the identifier and coordinate
  columns is decoded with its feature's value table. All codes are 1-based.

  Args:
    df: (dataframe) Pandas dataframe with integer-coded categorical columns.
      It is not modified.
    data_info: (dict) Dictionary containing information about the features.

  Returns:
    A new dataframe with the categorical columns replaced by their values.

  Raises:
    ValueError: If some column contains an invalid code.
  """
  good_cols = ["wals_code", "name", "latitude", "longitude",
               "genus", "family", "countrycodes"]
  feature_cols = [col for col in df.columns if col not in good_cols]

  result = df.copy()  # Leave the caller's dataframe untouched.

  # Convert genus and family separately.
  result["genus"] = _decode_codes(
      df["genus"], data_info[const.DATA_KEY_GENERA]["values"], "genus")
  result["family"] = _decode_codes(
      df["family"], data_info[const.DATA_KEY_FAMILIES]["values"], "family")

  # Process actual features.
  for col_name in feature_cols:
    value_list = data_info[const.DATA_KEY_FEATURES][col_name]["values"]
    result[col_name] = _decode_codes(df[col_name], value_list, col_name)
  return result


def main(unused_args):
  """Trains (or restores) a MIDAS model and writes imputed datasets.

  Raises:
    ValueError: If --input_dir is not specified.
  """
  if not FLAGS.input_dir:
    raise ValueError("Specify --input_dir!")
  original_df, categorical_columns, df = _load_df(FLAGS.training_set_name)
  data_info = _load_data_info()

  # Construct the graph.
  layer_structure = [int(layer) for layer in FLAGS.layer_structure]
  logging.info("Constructing model ...")
  logging.info("Parameters: layers: %s, batch_size: %d",
               layer_structure, FLAGS.batch_size)
  imputer = Midas(layer_structure=layer_structure,
                  train_batch=FLAGS.batch_size,
                  savepath=FLAGS.model_dir,
                  vae_layer=False, seed=89,
                  input_drop=0.75)
  imputer.build_model(df, softmax_columns=categorical_columns)

  # Train.
  if FLAGS.train:
    logging.info("Training model (%d epochs) ...", FLAGS.num_epochs)
    imputer.train_model(training_epochs=FLAGS.num_epochs)
    logging.info("Model saved to \"%s\".", FLAGS.model_dir)
  else:
    logging.info("No training mode. Model will be restored from \"%s\" ...",
                 FLAGS.model_dir)

  # Generate imputed datasets.
  output_dir = FLAGS.output_datasets_dir
  if output_dir:
    if not os.path.isdir(output_dir):
      # `DataFrame.to_csv` does not create missing directories.
      os.makedirs(output_dir)
    logging.info("Saving %d datasets to \"%s\" ...",
                 FLAGS.num_datasets, output_dir)
    imputations = imputer.generate_samples(m=FLAGS.num_datasets).output_list
    for n, imputed_df in enumerate(imputations):
      file_path = os.path.join(output_dir, "midas_%d.csv" % n)
      if not FLAGS.save_softmax_columns:
        logging.info("[%d] Converting data from softmax ...", n)
        imputed_df = _convert_from_softmax(original_df, imputed_df, data_info)
        if not FLAGS.categorical_as_ints:
          logging.info("[%d] Converting categoricals back to strings ...", n)
          imputed_df = _convert_categorical(imputed_df, data_info)
      logging.info("[%d] Saving dataset to \"%s\" ...", n, file_path)
      imputed_df.to_csv(file_path, sep="|", index=False)


if __name__ == "__main__":
  # absl's `app.run` parses the command-line flags, then calls `main`.
  app.run(main)
